// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Compiler.Transforms;
using Microsoft.ML.Probabilistic.Factors;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Models;
using Microsoft.ML.Probabilistic.Models.Attributes;

namespace Microsoft.ML.Probabilistic.Algorithms
{
    /// <summary>
    /// Gibbs sampling algorithm - includes block Gibbs sampling
    /// </summary>
    public class GibbsSampling : AlgorithmBase, IAlgorithm
    {
        /// <summary>
        /// Number of iterations added to <see cref="BurnIn"/> when no explicit
        /// default number of iterations has been set.
        /// </summary>
        private const int IterationsAfterBurnIn = 2000;

        /// <summary>
        /// Value that <see cref="UseSideChannels"/> takes on newly constructed instances.
        /// </summary>
        public static bool DefaultSideChannels = false;

        /// <summary>
        /// When true, <see cref="GetMessagePrototype"/> returns the marginal prototype unchanged
        /// for marginal channels (no <c>GibbsMarginal</c> wrapper is built) and
        /// <see cref="GetQueryTypeBinding"/> returns null for every query type.
        /// </summary>
        public bool UseSideChannels = DefaultSideChannels;

        #region IAlgorithm Members

        /// <summary>
        /// Gets the 'variable' pseudo-factor used by Gibbs sampling.
        /// </summary>
        /// <param name="derived">True to select the derived-variable factor</param>
        /// <param name="initialised">Not used by this algorithm</param>
        /// <returns>A delegate wrapping <c>Factor.DerivedVariableGibbs</c> or <c>Factor.VariableGibbs</c></returns>
        public override Delegate GetVariableFactor(bool derived, bool initialised)
        {
            if (derived)
            {
                return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Factor.DerivedVariableGibbs);
            }

            return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Factor.VariableGibbs);
        }

        /// <summary>
        /// Gets the suffix for Gibbs Sampling operator methods
        /// </summary>
        /// <param name="factorAttributes">Attributes attached to the factor</param>
        /// <returns>"Gibbs" for a variable factor, otherwise "AverageConditional"</returns>
        public override string GetOperatorMethodSuffix(List<ICompilerAttribute> factorAttributes)
        {
            return ContainsVariableFactorAttribute(factorAttributes) ? "Gibbs" : "AverageConditional";
        }

        /// <summary>
        /// Gets the suffix for Gibbs Sampling evidence method
        /// Evidence is not supported or supportable for Gibbs. The message
        /// update methods are marked as unsupported so that an appropriate
        /// error message is generated by the model compiler
        /// </summary>
        /// <param name="factorAttributes">Attributes attached to the factor</param>
        /// <returns>"GibbsEvidence" for a variable factor, otherwise "LogEvidenceRatio"</returns>
        public override string GetEvidenceMethodName(List<ICompilerAttribute> factorAttributes)
        {
            return ContainsVariableFactorAttribute(factorAttributes) ? "GibbsEvidence" : "LogEvidenceRatio";
        }

        /// <summary>
        /// Name of the algorithm
        /// </summary>
        public override string Name
        {
            get { return "GibbsSampling"; }
        }

        /// <summary>
        /// Short name of the algorithm
        /// </summary>
        public override string ShortName
        {
            get { return "Gibbs"; }
        }

        /// <summary>
        /// Gets the operator which converts a message to/from another algorithm.
        /// Gibbs sampling defines no such conversion, so this always throws.
        /// </summary>
        /// <param name="channelType">Type of message</param>
        /// <param name="alg2">The other algorithm</param>
        /// <param name="isFromFactor">True if from, false if to</param>
        /// <param name="args">Where to add arguments of the operator</param>
        /// <returns>Never returns</returns>
        /// <exception cref="InferCompilerException">Always thrown</exception>
        public override MethodReference GetAlgorithmConversionOperator(Type channelType, IAlgorithm alg2, bool isFromFactor, List<object> args)
        {
            throw new InferCompilerException("Cannot convert from " + Name + " to " + alg2.Name);
        }

        /// <summary>
        /// Get the message prototype in the specified direction
        /// </summary>
        /// <param name="channelInfo">The channel information</param>
        /// <param name="direction">The direction</param>
        /// <param name="marginalPrototypeExpression">The marginal prototype expression</param>
        /// <param name="path">Path name of message</param>
        /// <param name="queryTypes">The set of queries to support.  Only used for marginal channels.</param>
        /// <returns>An expression for the method prototype</returns>
        public override IExpression GetMessagePrototype(
            ChannelInfo channelInfo, MessageDirection direction,
            IExpression marginalPrototypeExpression, string path, IList<QueryType> queryTypes)
        {
            CodeBuilder builder = CodeBuilder.Instance;

            if (channelInfo.IsMarginal)
            {
                // Only the forward marginal message is wrapped in a GibbsMarginal, and only
                // when side channels are disabled; otherwise the prototype is used as-is.
                if (direction != MessageDirection.Forwards || UseSideChannels)
                {
                    return marginalPrototypeExpression;
                }

                // Work out which quantities the GibbsMarginal must accumulate.
                bool estimateMarginal = false;
                bool collectSamples = false;
                bool collectDistributions = false;
                foreach (QueryType queryType in queryTypes)
                {
                    if (queryType.Name == "Marginal")
                    {
                        estimateMarginal = true;
                    }
                    else if (queryType.Name == "Samples")
                    {
                        collectSamples = true;
                    }
                    else if (queryType.Name == "Conditionals")
                    {
                        collectDistributions = true;
                    }
                }

                // We want the marginal variable to be a GibbsMarginal over the appropriate
                // distribution type
                Type innermostMessageType = marginalPrototypeExpression.GetExpressionType();
                Type innermostElementType = Distribution.GetDomainType(innermostMessageType);
                Type distributionType = MessageTransform.GetDistributionType(
                    channelInfo.varInfo.varType, innermostElementType, innermostMessageType, true);
                Type marginalType = typeof (GibbsMarginal<,>).MakeGenericType(distributionType, channelInfo.varInfo.varType);

                // The supplied prototype can only seed the estimator when it already has the
                // full distribution type; otherwise fall back to a default value of that type.
                IExpression seedPrototype = (distributionType == innermostMessageType)
                                                ? marginalPrototypeExpression
                                                : builder.DefaultExpr(distributionType);

                return builder.NewObject(
                    marginalType, seedPrototype, Quoter.Quote(this.BurnIn), Quoter.Quote(this.Thin),
                    Quoter.Quote(estimateMarginal), Quoter.Quote(collectSamples), Quoter.Quote(collectDistributions));
            }

            Type prototypeType = marginalPrototypeExpression.GetExpressionType();
            if (path == "Distribution")
            {
                // Distribution messages use the marginal prototype directly.
                return marginalPrototypeExpression;
            }

            // Default is sample: the message is a default value of the innermost
            // (non-array) element type of the distribution's domain.
            Type sampleType = Distribution.GetDomainType(prototypeType);
            while (sampleType.IsArray)
            {
                sampleType = sampleType.GetElementType();
            }

            return builder.DefaultExpr(sampleType);
        }

        /// <summary>
        /// Allows the algorithm to modify the attributes on a factor. For example, in Gibbs sampling
        /// different message types are passed depending on the context. This is signalled to the MessageTransform
        /// by attaching a MessagePath attribute to the method invoke expression for the factor.
        /// If the factor is a 'variable' pseudo-factor (UsesEqualsDef) then all incoming variables are
        /// Distributions. Otherwise, incoming messages will depend on the grouping
        /// </summary>
        /// <param name="factorExpression">The factor expression</param>
        /// <param name="factorAttributes">Attribute registry</param>
        public override void ModifyFactorAttributes(IExpression factorExpression, AttributeRegistry<object, ICompilerAttribute> factorAttributes)
        {
            // Variable pseudo-factors are left untouched.
            if (factorAttributes.Has<IsVariableFactor>(factorExpression))
            {
                return;
            }

            // Process any Message Path attributes that may have been set by the Group transform
            foreach (MessagePathAttribute messagePath in factorAttributes.GetAll<MessagePathAttribute>(factorExpression))
            {
                messagePath.Path = (messagePath.FromDistance >= messagePath.ToDistance)
                                       ? "Distribution"
                                       : "CurrentSample";
            }
        }

        /// <summary>
        /// Get the default inference query types for a variable for this algorithm.
        /// The action is invoked for the Marginal query type and then the Samples query type.
        /// </summary>
        /// <param name="action">Callback invoked once per default query type</param>
        public override void ForEachDefaultQueryType(Action<QueryType> action)
        {
            action(QueryTypes.Marginal);
            action(QueryTypes.Samples);
        }

        /// <summary>
        /// Get the query type binding for Gibbs sampling - this is the path to the given query type
        /// relative to the raw marginal type.
        /// </summary>
        /// <param name="qt">The query type</param>
        /// <returns>
        /// Null when <see cref="UseSideChannels"/> is set; otherwise "Distribution", "Samples" or
        /// "Conditionals" for the matching query type, and the empty string for any other query type.
        /// </returns>
        public override string GetQueryTypeBinding(QueryType qt)
        {
            if (UseSideChannels)
            {
                return null;
            }

            if (qt == QueryTypes.Marginal)
            {
                return "Distribution";
            }

            if (qt == QueryTypes.Samples)
            {
                return "Samples";
            }

            if (qt == QueryTypes.Conditionals)
            {
                return "Conditionals";
            }

            return "";
        }

        #endregion

        /// <summary>
        /// Tests whether the attribute list marks the factor as a 'variable' pseudo-factor.
        /// </summary>
        /// <param name="factorAttributes">Attributes attached to the factor</param>
        /// <returns>True if any attribute is an <c>IsVariableFactor</c></returns>
        private static bool ContainsVariableFactorAttribute(List<ICompilerAttribute> factorAttributes)
        {
            // A direct type test is used rather than attributeType.IsAssignableFrom(typeof(IsVariableFactor)):
            // the latter asks the question in the wrong direction (it accepts base types of
            // IsVariableFactor and rejects types derived from it) and fails on a null entry.
            return factorAttributes.Exists(attribute => attribute is IsVariableFactor);
        }

        private int burnIn = 100;

        /// <summary>
        /// The number of samples to discard at the beginning
        /// </summary>
        public int BurnIn
        {
            get { return burnIn; }
            set { burnIn = value; }
        }

        private int thin = 5;

        /// <summary>
        /// Reduction factor when constructing sample and conditional lists
        /// </summary>
        public int Thin
        {
            get { return thin; }
            set { thin = value; }
        }

        // A negative value means "not set explicitly"; the getter then derives the default from burnIn.
        private int defaultNumberOfIterations = -1;

        /// <summary>
        /// Default number of iterations for Gibbs sampling.
        /// Until a non-negative value is assigned, this is <see cref="BurnIn"/> plus 2000.
        /// </summary>
        public override int DefaultNumberOfIterations
        {
            get
            {
                if (defaultNumberOfIterations < 0)
                {
                    return burnIn + IterationsAfterBurnIn;
                }

                return defaultNumberOfIterations;
            }
            set { defaultNumberOfIterations = value; }
        }
    }
}